# A set of functions for fitting XPS peaks, using the DOS calculated by DFT.


import numpy as np
import matplotlib.pyplot as plt
import math
from scipy.fftpack import fft, ifft
from scipy.optimize import curve_fit

#Calculate absorption coefficient A for states with quantum number l, ignoring coupling matrix element
def calcAbsorption(energy_array, jdos_array, l, tolerance = 0.001):
    """Return the absorption coefficient for states with quantum number ``l``.

    Every JDOS sample is divided by ``2*(2*l+1)`` times the energy stored at
    the same index; the coupling matrix element is not included.

    Parameters
    ----------
    energy_array : indexable sequence of numbers
        Energies. Only the first ``len(jdos_array)`` entries are read.
    jdos_array : sequence of numbers
        Joint density of states samples.
    l : number
        Quantum number ``l`` entering the ``2*(2*l+1)`` factor.
    tolerance : float, optional
        Accepted for interface compatibility; not used by the calculation.

    Returns
    -------
    numpy.ndarray
        One value per entry of ``jdos_array``.
    """
    prefactor = 2*(2*l+1)
    return np.array([
        jdos_value/(prefactor*energy_array[idx])
        for idx, jdos_value in enumerate(jdos_array)
    ])

# Calculate lorentzian profile
def lorentzian(x, x0, gamma):
    """Evaluate a Lorentzian profile centred at ``x0``.

    Computes ``gamma / (2*pi) / ((x - x0)**2 + (gamma/2)**2)``, i.e. the
    area-normalised Lorentzian whose full width at half maximum is ``gamma``.

    Parameters
    ----------
    x : number or numpy.ndarray
        Position(s) at which the profile is evaluated.
    x0 : number
        Centre of the profile.
    gamma : number
        Full width at half maximum.
    """
    halfWidth = 1/2*gamma
    offset = x - x0
    return 1/(2*math.pi) * gamma/(offset**2 + halfWidth**2)

# Calculate gaussian profile (width given as FWHM)
def gaussian(x, x0, FWHM):
    """Evaluate an area-normalised Gaussian profile centred at ``x0``.

    The width is given as a full width at half maximum and converted to the
    standard deviation with ``sigma = FWHM / (2*sqrt(2*ln 2))``.

    NumPy functions are used throughout so that ``x`` may be a scalar or an
    array; the previous ``math.exp`` implementation only accepted scalars.

    Parameters
    ----------
    x : number or array-like
        Position(s) at which the profile is evaluated.
    x0 : number
        Centre of the profile.
    FWHM : number
        Full width at half maximum.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Profile value(s), with the same shape as ``x``.
    """
    sigma = FWHM/(2*np.sqrt(2*np.log(2)))
    offset = np.asarray(x) - x0
    return 1/np.sqrt(2*np.pi*sigma**2) * np.exp(-offset**2/(2*sigma**2))

# Calculate Voigt profile (convolution of a lorentzian and a gaussian profile)
def voigt (x, x0, FWHM_L, FWHM_G):
    """Return the discrete convolution of a Lorentzian and a Gaussian profile.

    Both profiles are sampled on ``x`` and centred at ``x0``. The widths are
    taken from ``FWHM_L`` and ``FWHM_G``; the previous body referenced the
    undefined names ``gamma`` and ``sigma`` and raised ``NameError``.

    Parameters
    ----------
    x : number or array-like
        Sample positions; flattened to one dimension before use.
    x0 : number
        Centre shared by both profiles.
    FWHM_L : number
        Width passed to ``lorentzian``.
    FWHM_G : number
        Width passed to ``gaussian``.

    Returns
    -------
    numpy.ndarray
        ``np.convolve(..., mode='full')`` of the two sampled profiles, so it
        has ``2*N - 1`` entries for ``N`` samples and is not multiplied by the
        grid spacing.
    """
    grid = np.ravel(x)
    # Sample point by point so this works even if a profile function only
    # accepts scalar arguments.
    lorentzProfile = np.array([lorentzian(xi, x0, FWHM_L) for xi in grid])
    gaussProfile = np.array([gaussian(xi, x0, FWHM_G) for xi in grid])
    return np.convolve(lorentzProfile, gaussProfile, mode='full')

# Calculate JDOS 
def calcJDOS(occStates, unoccStates, energySpacing, fermiEnergyPos):
    """Return the joint density of states as a scaled discrete convolution.

    The occupied states up to and including index ``fermiEnergyPos`` are
    reversed and convolved (``mode='full'``) with the unoccupied states from
    index ``fermiEnergyPos`` onwards; the result is multiplied by
    ``energySpacing``.

    Parameters
    ----------
    occStates, unoccStates : sequence of numbers
        Occupied and unoccupied densities of states on the same grid.
    energySpacing : number
        Factor applied to the convolution.
    fermiEnergyPos : int
        Index at which the two slices are split.

    Returns
    -------
    numpy.ndarray
    """
    occupiedReversed = np.flip(occStates[:fermiEnergyPos + 1], 0)
    unoccupiedAbove = unoccStates[fermiEnergyPos:]
    return energySpacing * np.convolve(occupiedReversed, unoccupiedAbove, mode='full')

# Calculate XPS peak before convolution
def calcPeak(absFunc,energyGrid,peakEnergyGrid,tGrid):
    """Calculate the XPS peak before convolution with an instrument profile.

    For every ``t`` in ``tGrid`` the inner quantity

        exp( integral of absFunc(E)/E * (exp(-1j*E*t) - 1) dE )

    is evaluated with the trapezoidal rule over ``energyGrid``. The peak at
    each energy ``en`` in ``peakEnergyGrid`` is then

        1/(2*pi) * integral of exp(1j*en*t) * inner(t) dt

    again with the trapezoidal rule, over ``tGrid``.

    Parameters
    ----------
    absFunc : numpy.ndarray
        Absorption function sampled on ``energyGrid``.
    energyGrid : numpy.ndarray
        Energies; must not contain zeros because ``absFunc`` is divided by it.
    peakEnergyGrid : iterable of numbers
        Energies at which the peak is evaluated.
    tGrid : numpy.ndarray
        Grid of the conjugate variable ``t``.

    Returns
    -------
    numpy.ndarray
        Complex peak values, one per entry of ``peakEnergyGrid``.

    Notes
    -----
    The integration steps are read from elements 3 and 4 of each grid, so both
    grids need at least five entries and are treated as uniformly spaced.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; pick
    # whichever this NumPy version provides.
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    tSpacing = tGrid[4]-tGrid[3]
    energySpacing = energyGrid[4]-energyGrid[3]

    # Independent of t, so computed once outside the loop.
    weight = absFunc/(energyGrid)

    innerIntegral = np.array([
        np.exp(integrate(weight*(np.exp(-1j*energyGrid*t)-1)*energySpacing))
        for t in tGrid
    ])
    peak = (1/(2*math.pi))*np.array([
        integrate(np.exp(1j*en*tGrid) * innerIntegral * tSpacing)
        for en in peakEnergyGrid
    ])
    return peak


import numpy as np
import matplotlib.pyplot as plt

def processDOS(energyGrid, dos, fermiLevel=0):
    """Determine the occupied and unoccupied states of a density of states.

    Parameters
    ----------
    energyGrid : sequence of numbers
        Energies at which ``dos`` is sampled.
    dos : indexable sequence of numbers
        Density of states; must have at least ``len(energyGrid)`` entries.
    fermiLevel : number, optional
        States at energies ``<= fermiLevel`` count as occupied, states at
        energies ``> fermiLevel`` as unoccupied.

    Returns
    -------
    numPoints : int
        Number of grid points.
    energySpacing : float
        ``(max - min) / (numPoints - 1)``, the mean distance between adjacent
        grid points (``0.0`` for grids with fewer than two points).
    occStates : list
        ``dos`` with entries above the Fermi level replaced by ``0``.
    unoccStates : list
        ``dos`` with entries at or below the Fermi level replaced by ``0``.
    """
    numPoints = len(energyGrid)

    # N grid points span N-1 intervals; dividing by N (as was done before)
    # underestimates the spacing.
    if numPoints > 1:
        energySpacing = (np.max(energyGrid)-np.min(energyGrid))/(numPoints - 1)
    else:
        energySpacing = 0.0

    occStates = [0 if energy > fermiLevel else dos[idx] for idx, energy in enumerate(energyGrid)]
    unoccStates = [0 if energy <= fermiLevel else dos[idx] for idx, energy in enumerate(energyGrid)]

    return numPoints, energySpacing, occStates, unoccStates

def plotDOS(energyGrid, dos, occStates, unoccStates, scaled = False):
    """Plot the total DOS together with its occupied and unoccupied parts.

    The total DOS is drawn as a thick, semi-transparent black line, the
    occupied states as a line with a filled area down to zero, and the
    unoccupied states as a red line. The figure is displayed with
    ``plt.show()``.

    Parameters
    ----------
    energyGrid : sequence of numbers
        x values shared by all curves.
    dos, occStates, unoccStates : sequence of numbers
        Curves to draw.
    scaled : bool, optional
        Selects the figure title: 'Scaled density of states' when true,
        'Calculated density of states' otherwise.
    """
    axes = plt.gca()

    axes.plot(energyGrid, dos, 'k-', linewidth=6, alpha=0.7)
    axes.plot(energyGrid, occStates, linewidth=1.5)
    axes.fill_between(energyGrid, 0, occStates, linewidth=1.5, alpha=0.7)
    axes.plot(energyGrid, unoccStates, 'r', linewidth=1.5)

    axes.set_xlabel("$E-E_F$ [eV]")
    axes.set_ylabel("$DOS$ $[states$ $eV^{-1}]$")
    axes.set_title('Scaled density of states' if scaled else 'Calculated density of states')

    plt.show()
    
def plotXPS(energyXPS, XPS):
    """Plot experimental XPS data on an inverted binding-energy axis.

    The y range is fixed to ``[-0.05, 1.05]`` and the figure is displayed
    with ``plt.show()``.

    Parameters
    ----------
    energyXPS : sequence of numbers
        x values of the spectrum.
    XPS : sequence of numbers
        y values of the spectrum.
    """
    axes = plt.gca()
    axes.plot(energyXPS, XPS)
    axes.set_xlabel("$Binding$ $Energy$ $[eV]$")
    axes.invert_xaxis()
    axes.set_ylabel("$Intensity$ $[a.u.]$")
    axes.set_ylim(-0.05, 1.05)
    axes.set_title('Experimental XPS Data')
    plt.show()
    
def outputPlots(energyGridJdos,jdos,convGrid,finalPeakPreconv,finalPeakNorm,convGridShifted,finalPeakScaled):
    """Draw the four result plots, calling ``plt.show()`` after each one.

    1. ``jdos`` against ``energyGridJdos``.
    2. ``finalPeakPreconv`` against ``convGrid`` (inverted x axis).
    3. ``finalPeakNorm`` against ``convGrid``, x limited to ``[-2, 2]`` and
       then inverted, y limited to ``[-0.05, 1.05]``.
    4. ``finalPeakScaled`` against ``convGridShifted`` (inverted x axis),
       y limited to ``[-0.05, 1.05]``.
    """
    # 1) Joint density of states.
    axes = plt.gca()
    axes.plot(energyGridJdos, jdos)
    axes.set_xlabel("$E-E_F$ [eV]")
    axes.set_ylabel("$JDOS$ $[a.u.]$")
    plt.show()

    # 2) Peak before convolution.
    axes = plt.gca()
    axes.plot(convGrid, finalPeakPreconv)
    axes.set_xlabel("$Binding$ $Energy$ $[eV]$")
    axes.invert_xaxis()
    axes.set_ylabel("$Height$ $Normalized$ $Intensity$ $[a.u.]$")
    axes.set_title("DFT+WW profile, pre-convolution")
    plt.show()

    # 3) Normalised peak after convolution. The x limits are set before the
    #    axis is inverted, so the inversion applies to the restricted range.
    axes = plt.gca()
    axes.plot(convGrid, finalPeakNorm)
    axes.set_xlabel("$Binding$ $Energy$ $[eV]$")
    axes.set_xlim(-2, 2)
    axes.invert_xaxis()
    axes.set_ylabel("$Height$ $Normalized$ $Intensity$ $[a.u.]$")
    axes.set_ylim(-0.05, 1.05)
    axes.set_title("DFT+WW profile, after convolution")
    plt.show()

    # 4) Scaled peak on the shifted grid.
    axes = plt.gca()
    axes.plot(convGridShifted, finalPeakScaled)
    axes.set_xlabel("$Binding$ $Energy$ $[eV]$")
    axes.invert_xaxis()
    axes.set_ylabel("$Intensity$ $[a.u.]$")
    axes.set_ylim(-0.05, 1.05)
    axes.set_title("DFT+WW profile, after convolution, scaled")
    plt.show()

def plotXPSFit(energyXPS, XPS, convGridShifted,finalPeakScaled):
    """Overlay the calculated peak (line) on the XPS data (black markers).

    The x range is set to ``[-2.2, 1.8]`` and then inverted, the y range is
    fixed to ``[-0.05, 1.05]``, and the figure is displayed with
    ``plt.show()``.

    Parameters
    ----------
    energyXPS, XPS : sequence of numbers
        Data drawn as scattered points.
    convGridShifted, finalPeakScaled : sequence of numbers
        Calculated profile drawn as a line.
    """
    plt.plot(convGridShifted, finalPeakScaled)
    plt.scatter(energyXPS, XPS, c="black")

    axes = plt.gca()
    axes.set_xlabel("$Binding$ $Energy$ $[eV]$")
    axes.set_xlim(-2.2, 1.8)
    axes.invert_xaxis()
    axes.set_ylabel("$Intensity$ $[a.u.]$")
    axes.set_ylim(-0.05, 1.05)
    axes.set_title("DFT+WW profile, after convolution, scaled")

    plt.show()
        
def dosToPeak(energyGrid,dos,paramsDict,fermiLevel=0,visualizeInput=True,visualizeOutputs=True,tRange=500,numt=6000,peakGridNum=2001,peakGridRange=20):
    """Turn a density of states into a broadened, scaled XPS peak profile.

    Steps: split the DOS at ``fermiLevel``, scale it, build the JDOS and the
    absorption function (with ``l = 0``), evaluate the peak with ``calcPeak``,
    convolve its real part with a Lorentzian-Gaussian kernel, normalise to
    unit height and multiply by the requested intensity.

    Parameters
    ----------
    energyGrid : numpy.ndarray
        Energies at which ``dos`` is sampled.
    dos : sequence of numbers
        Density of states.
    paramsDict : dict
        Must provide the keys shown in this example:
        {'bindingEnergy':0.0,'intensity':1,'lorentzianWidth':0.48,'gaussianWidth':0.62,'scaleDOS':0.0187}
        The DOS and its occupied/unoccupied parts are multiplied by the
        square root of ``'scaleDOS'``.
    fermiLevel : number, optional
        Energy separating occupied from unoccupied states.
    visualizeInput : bool, optional
        Plot the unscaled DOS before the calculation.
    visualizeOutputs : bool, optional
        Plot the scaled DOS and the result plots after the calculation.
    tRange, numt : number, int, optional
        The internal t grid is ``numt`` points from ``-tRange/2`` to ``tRange``.
    peakGridNum, peakGridRange : int, number, optional
        The peak grid is ``peakGridNum`` points from ``-peakGridRange/2`` to
        ``peakGridRange/2``.

    Returns
    -------
    tuple
        ``(energyGrid, dos, occStates, unoccStates, dosScaled,
        occStatesScaled, unoccStatesScaled, energyGridJdos, jdos, convGrid,
        finalPeakPreconv, finalPeakNorm, convGridShifted, finalPeakScaled)``
    """
    # Fit parameters.
    bindingEnergy = paramsDict['bindingEnergy']
    intensity = paramsDict['intensity']
    lorentzianWidth = paramsDict['lorentzianWidth']
    gaussianWidth = paramsDict['gaussianWidth']
    # Presumably the square root is used because the JDOS combines occupied
    # and unoccupied states multiplicatively - confirm.
    scaleDOS = math.sqrt(paramsDict['scaleDOS'])

    numPoints, energySpacing, occStates, unoccStates = processDOS(energyGrid, dos, fermiLevel)

    if visualizeInput:
        plotDOS(energyGrid, dos, occStates, unoccStates)

    # Grid used internally in the peak calculation (the peak can be sensitive
    # to these numerical values).
    # NOTE(review): the range runs from -tRange/2 to +tRange, i.e. it is not
    # symmetric about zero - confirm this is intended.
    tGrid = np.linspace(-tRange/2, tRange, num=numt)

    # Grids the final peak comes out on (the peak is sensitive to these
    # numerical values).
    gridLow = -peakGridRange/2
    gridHigh = peakGridRange/2
    lastIndex = numPoints - 1
    peakEnergyGrid = np.linspace(gridLow, gridHigh, num=peakGridNum)
    convGrid = np.linspace(gridLow, gridHigh, num=peakGridNum)
    convGridShifted = np.linspace(gridLow - bindingEnergy, gridHigh - bindingEnergy, num=peakGridNum)

    # Broadening kernel: Lorentzian convolved with Gaussian, both centred at 0.
    lorentzProfile = np.array([lorentzian(en, 0, lorentzianWidth) for en in peakEnergyGrid])
    gaussProfile = np.array([gaussian(en, 0, gaussianWidth) for en in peakEnergyGrid])
    voigtProfile = np.convolve(lorentzProfile, gaussProfile, mode="same")

    # Scaled copies of the DOS and its two parts.
    dosScaled = np.multiply(dos, scaleDOS)
    occStatesScaled = np.multiply(occStates, scaleDOS)
    unoccStatesScaled = np.multiply(unoccStates, scaleDOS)

    # Index of the grid point closest to the Fermi level; calcJDOS splits the
    # state arrays there.
    fermiEnergyPos = np.argmin(np.abs(energyGrid - fermiLevel))
    # Shift the grid to start just above zero so later divisions by the
    # energy do not hit zero.
    energyGridJdos = energyGrid - np.min(energyGrid) + 0.00001

    jdos = calcJDOS(occStatesScaled, unoccStatesScaled, energySpacing, fermiEnergyPos)
    absFunc = calcAbsorption(energyGridJdos, jdos, 0)
    complexPeak = calcPeak(absFunc[0:lastIndex], energyGridJdos[0:lastIndex], peakEnergyGrid, tGrid)

    finalPeakPreconv = np.real(complexPeak)
    convolvedPeak = np.convolve(finalPeakPreconv, voigtProfile, mode='same')
    finalPeakNorm = convolvedPeak / np.max(np.real(convolvedPeak))
    finalPeakScaled = finalPeakNorm * intensity

    if visualizeOutputs:
        plotDOS(energyGrid, dosScaled, occStatesScaled, unoccStatesScaled, scaled=True)
        outputPlots(energyGridJdos, jdos, convGrid, finalPeakPreconv, finalPeakNorm, convGridShifted, finalPeakScaled)

    return (
        energyGrid, dos, occStates, unoccStates,
        dosScaled, occStatesScaled, unoccStatesScaled,
        energyGridJdos, jdos,
        convGrid, finalPeakPreconv, finalPeakNorm,
        convGridShifted, finalPeakScaled,
    )


def writeFiles(energyGrid, dos, occStates, unoccStates, dosScaled, occStatesScaled, unoccStatesScaled, energyGridJdos, jdos, convGrid, finalPeakPreconv, finalPeakNorm, convGridShifted, finalPeakScaled, paramsDict, name = 'DFT+WW'):
    """Write the inputs and results of a peak calculation to four CSV files.

    The files are ``<name>_DOS.csv``, ``<name>_JDOS.csv``, ``<name>_XPS.csv``
    and ``<name>_INFO.csv``. ``name`` is now honoured; previously it was
    overwritten with ``"DFT+WW"`` inside the function, so every call wrote to
    the same four files.

    Parameters
    ----------
    energyGrid, dos, occStates, unoccStates, dosScaled, occStatesScaled,
    unoccStatesScaled : sequence of numbers
        Columns of the DOS file; must all have the same length.
    energyGridJdos, jdos : sequence of numbers
        Columns of the JDOS file; must have the same length.
    convGrid, finalPeakPreconv, finalPeakNorm, convGridShifted,
    finalPeakScaled : sequence of numbers
        Columns of the XPS file; must all have the same length.
    paramsDict : dict
        Must contain 'bindingEnergy', 'intensity', 'lorentzianWidth',
        'gaussianWidth' and 'scaleDOS'; written as one row to the INFO file.
    name : str, optional
        Sample name used as the file-name prefix (may include a directory).
    """
    prefix = str(name)
    nameDOS = prefix + "_DOS.csv"
    nameJDOS = prefix + "_JDOS.csv"
    nameXPS = prefix + "_XPS.csv"
    nameINFO = prefix + "_INFO.csv"

    bindingEnergy = paramsDict['bindingEnergy']
    intensity = paramsDict['intensity']
    lorentzianWidth = paramsDict['lorentzianWidth']
    gaussianWidth = paramsDict['gaussianWidth']
    scaleDOS = paramsDict['scaleDOS']

    # Columns: E-EF, DOS_UNSCALED, OCCUPIED_UNSCALED, UNOCCUPIED_UNSCALED,
    #          DOS_SCALED, OCCUPIED_SCALED, UNOCCUPIED_SCALED
    np.savetxt(nameDOS, np.c_[energyGrid, dos, occStates, unoccStates, dosScaled, occStatesScaled, unoccStatesScaled], delimiter=",")

    # Columns: E-EF, JDOS_SCALED
    np.savetxt(nameJDOS, np.c_[energyGridJdos, jdos], delimiter=",")

    # Columns: BINDINGENERGY, PEAK_PRECONVOLUTION, PEAK_NORMALIZED,
    #          BINDINGENERGY_SHIFTED, PEAK_SCALED
    np.savetxt(nameXPS, np.c_[convGrid, finalPeakPreconv, finalPeakNorm, convGridShifted, finalPeakScaled], delimiter=",")

    # Single row: BINDINGENERGY, INTENSITY, LORENTZIANWIDTH, GAUSSIANWIDTH, SCALEDOS
    np.savetxt(nameINFO, np.c_[bindingEnergy, intensity, lorentzianWidth, gaussianWidth, scaleDOS], delimiter=",")
    